import torch
import torch.nn as nn
from torch.distributions import Categorical, Bernoulli

from math import exp
import numpy as np
from utils import to_tensor


class OptionCriticFeatures(nn.Module):
    def __init__(self,
                in_features,
                num_actions,
                num_options,
                temperature=1.0,
                eps_start=1.0,
                eps_min=0.1,
                eps_decay=int(1e6),
                eps_test=0.05,
                device='cpu',
                testing=False):

        super(OptionCriticFeatures, self).__init__()

        self.in_features = in_features
        self.num_actions = num_actions
        self.num_options = num_options
        self.device = device
        self.testing = testing

        self.temperature = temperature
        self.eps_min   = eps_min
        self.eps_start = eps_start
        self.eps_decay = eps_decay
        self.eps_test  = eps_test
        self.num_steps = 0
        
        self.features = nn.Sequential(
            nn.Linear(in_features, 32),
            nn.ReLU(),
            nn.Linear(32, 32),
            nn.ReLU(),
        )

        self.Q            = nn.Linear(32, num_options)                 # Policy-Over-Options
        self.terminations = nn.Linear(32, num_options)                 # Option-Termination

        # Intra-option policies: for every option, six independent action heads, each
        # producing scores over 4 discrete actions. nn.ModuleList registers the layers
        # as sub-modules, so their parameters are trained and moved by self.to(device).
        self.action_list = nn.ModuleList([
            nn.ModuleList([nn.Linear(32, 4) for _ in range(6)])
            for _ in range(num_options)
        ])

        self.to(device)
        self.train(not testing)

    def get_state(self, obs):
        # Feature observations arrive with ndim < 4, so a leading dimension is added;
        # batched callers (e.g. critic_loss) squeeze it off again after the forward pass.
        if obs.ndim < 4:
            obs = obs.unsqueeze(0)
        obs = obs.to(self.device)
        state = self.features(obs)
        return state

    def get_Q(self, state):
        return self.Q(state)
    
    def predict_option_termination(self, state, current_option):
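        # Sample whether the current option terminates in this state, and also return
        # the greedy next option under the policy-over-options Q values.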
        termination = self.terminations(state)[:, current_option].sigmoid()
        option_termination = Bernoulli(termination).sample()
        Q = self.get_Q(state)
        next_option = Q.argmax(dim=-1)
        return bool(option_termination.item()), next_option.item()

    def get_terminations(self, state):
        return self.terminations(state).sigmoid() 

    def get_action(self, state, option):
        # Sample one action from each of the six action heads of the selected option.
        # Categorical normalizes the non-negative sigmoid outputs, so no explicit
        # softmax is needed; log-probabilities and entropies are summed over the heads.
        actions = []
        logp = 0
        entropy = 0
        for head in self.action_list[option]:
            probs = torch.sigmoid(head(state))
            if torch.isnan(probs).any():  # debug guard: flag NaNs in the head output
                print(probs)
            dist = Categorical(probs)
            action = dist.sample()
            actions.append(action)
            logp += dist.log_prob(action)
            entropy += dist.entropy()

        return actions, logp, entropy
    
    def greedy_option(self, state):
        Q = self.get_Q(state)
        return Q.argmax(dim=-1).item()

    @property
    def epsilon(self):
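        # Exponential decay from eps_start towards eps_min with time constant eps_decay;
        # evaluation uses a fixed eps_test. Note that reading this property advances num_steps.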
        if not self.testing:
            eps = self.eps_min + (self.eps_start - self.eps_min) * exp(-self.num_steps / self.eps_decay)
            self.num_steps += 1
        else:
            eps = self.eps_test
        return eps


def critic_loss(model, model_prime, data_batch, args):
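    # One-step TD update for Q(s, w): the bootstrap either continues with the current
    # option or switches to the best option, weighted by the termination probability,
    # and is computed with the prime (target) network for stability.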
    obs, options, rewards, next_obs, dones = data_batch
    batch_idx = torch.arange(len(options)).long()
    options   = torch.tensor(options, dtype=torch.long, device=model.device)
    rewards   = torch.tensor(rewards, dtype=torch.float32, device=model.device)
    masks     = 1 - torch.tensor(dones, dtype=torch.float32, device=model.device)

    # The loss is the TD loss of Q and the update target, so we need to calculate Q
    states = model.get_state(to_tensor(obs)).squeeze(0)
    Q      = model.get_Q(states)
    
    # The update target contains Q_next; for stable learning we use the prime (target) network here
    next_states_prime = model_prime.get_state(to_tensor(next_obs)).squeeze(0)
    next_Q_prime      = model_prime.get_Q(next_states_prime)  # gradients are cut via gt.detach() below

    # Additionally, we need the beta probabilities of the next state
    next_states            = model.get_state(to_tensor(next_obs)).squeeze(0)
    next_termination_probs = model.get_terminations(next_states).detach()
    next_options_term_prob = next_termination_probs[batch_idx, options]

    # Now we can calculate the update target gt
    gt = rewards + masks * args.gamma * \
        ((1 - next_options_term_prob) * next_Q_prime[batch_idx, options] + next_options_term_prob  * next_Q_prime.max(dim=-1)[0])

    # to update Q we want to use the actual network, not the prime
    td_err = (Q[batch_idx, options] - gt.detach()).pow(2).mul(0.5).mean()
    return td_err

def actor_loss(obs, option, reward, logp, entropy, done, next_obs, model, model_prime, args):
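    # Combines the termination gradient (push the termination probability up when the
    # current option has no advantage over the best option) with the intra-option policy
    # gradient towards the one-step target gt, plus entropy regularization.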
    state = model.get_state(to_tensor(obs))
    next_state = model.get_state(to_tensor(next_obs))
    next_state_prime = model_prime.get_state(to_tensor(next_obs))

    option_term_prob = model.get_terminations(state)[:, option]                    # termination probability of the current option in s
    next_option_term_prob = model.get_terminations(next_state)[:, option].detach() # termination probability in the next state s'

    Q = model.get_Q(state).detach().squeeze().clone()                       # option values Q(s, .)
    next_Q_prime = model_prime.get_Q(next_state_prime).detach().squeeze()   # option values Q(s', .) from the prime network

    # Target update gt
    gt = reward + (1 - done) * args.gamma * \
        ((1 - next_option_term_prob) * next_Q_prime[option] + next_option_term_prob  * next_Q_prime.max(dim=-1)[0])

    # The termination loss
    termination_loss = option_term_prob * (Q[option].detach() - Q.max(dim=-1)[0].detach() + args.termination_reg) * (1 - done)
    
    # actor-critic policy gradient with entropy regularization
    policy_loss = - logp * (gt.detach() - Q[option]) - args.entropy_reg * entropy
    actor_loss = termination_loss + policy_loss
    return actor_loss

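

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (an illustrative addition, not part of the original
# training code): it wires OptionCriticFeatures, actor_loss and critic_loss
# together on synthetic data. The `args` namespace (gamma, termination_reg,
# entropy_reg), the RMSprop optimizer and the toy batch are assumptions made
# for illustration; utils.to_tensor is assumed to convert numpy observations
# to float tensors, as the functions above already rely on.
if __name__ == '__main__':
    import copy
    from argparse import Namespace

    args = Namespace(gamma=0.99, termination_reg=0.01, entropy_reg=0.01)  # hypothetical hyperparameters
    model = OptionCriticFeatures(in_features=8, num_actions=4, num_options=4)
    model_prime = copy.deepcopy(model)  # target ("prime") network
    optim = torch.optim.RMSprop(model.parameters(), lr=2.5e-4)

    # One synthetic transition
    obs = np.random.randn(8).astype(np.float32)
    next_obs = np.random.randn(8).astype(np.float32)
    state = model.get_state(to_tensor(obs))
    option = model.greedy_option(state)
    actions, logp, entropy = model.get_action(state, option)

    # Actor update from the single transition
    loss = actor_loss(obs, option, 1.0, logp, entropy, False, next_obs,
                      model, model_prime, args)

    # Critic update from a toy two-transition batch: (obs, options, rewards, next_obs, dones)
    batch = (np.stack([obs, next_obs]), [option, option], [1.0, 0.0],
             np.stack([next_obs, obs]), [False, True])
    loss = loss + critic_loss(model, model_prime, batch, args)

    optim.zero_grad()
    loss.backward()
    optim.step()
    print('smoke-test loss:', loss.item())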